{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.5/dist-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n",
      "  \"This module will be removed in 0.20.\", DeprecationWarning)\n"
     ]
    }
   ],
   "source": [
    "import bert_model as modeling\n",
    "import tensorflow as tf\n",
    "import os\n",
    "import numpy as np\n",
    "from utils import *\n",
    "from sklearn.cross_validation import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['negative', 'positive']\n",
      "10662\n",
      "10662\n"
     ]
    }
   ],
   "source": [
    "trainset = sklearn.datasets.load_files(\n",
    "    container_path = 'data', encoding = 'UTF-8'\n",
    ")\n",
    "trainset.data, trainset.target = separate_dataset(trainset, 1.0)\n",
    "print(trainset.target_names)\n",
    "print(len(trainset.data))\n",
    "print(len(trainset.target))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vocab from size: 20332\n",
      "Most common words [('film', 1453), ('movie', 1270), ('one', 727), ('like', 721), ('story', 477), ('much', 386)]\n",
      "Sample data [534, 2424, 3590, 10310, 36, 9328, 217, 150, 19, 3747] ['rock', 'destined', '21st', 'centurys', 'new', 'conan', 'hes', 'going', 'make', 'splash']\n"
     ]
    }
   ],
   "source": [
    "concat = ' '.join(trainset.data).split()\n",
    "vocabulary_size = len(list(set(concat)))\n",
    "data, count, dictionary, rev_dictionary = build_dataset(concat, vocabulary_size)\n",
    "print('vocab from size: %d'%(vocabulary_size))\n",
    "print('Most common words', count[4:10])\n",
    "print('Sample data', data[:10], [rev_dictionary[i] for i in data[:10]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "GO = dictionary['GO']\n",
    "PAD = dictionary['PAD']\n",
    "EOS = dictionary['EOS']\n",
    "UNK = dictionary['UNK']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "size_layer = 128\n",
    "num_layers = 2\n",
    "embedded_size = 128\n",
    "dimension_output = len(trainset.target_names)\n",
    "learning_rate = 1e-3\n",
    "maxlen = 50\n",
    "batch_size = 128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "\n",
    "bert_config = modeling.BertConfig(\n",
    "    vocab_size = len(dictionary),\n",
    "    hidden_size = size_layer,\n",
    "    num_hidden_layers = num_layers,\n",
    "    num_attention_heads = size_layer // 4,\n",
    "    intermediate_size = size_layer * 2,\n",
    ")\n",
    "\n",
    "input_ids = tf.placeholder(tf.int32, [None, maxlen])\n",
    "input_mask = tf.placeholder(tf.int32, [None, maxlen])\n",
    "segment_ids = tf.placeholder(tf.int32, [None, maxlen])\n",
    "label_ids = tf.placeholder(tf.int32, [None])\n",
    "is_training = tf.placeholder(tf.bool)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_model(\n",
    "    bert_config,\n",
    "    is_training,\n",
    "    input_ids,\n",
    "    input_mask,\n",
    "    segment_ids,\n",
    "    labels,\n",
    "    num_labels,\n",
    "    use_one_hot_embeddings,\n",
    "    reuse_flag = False,\n",
    "):\n",
    "    model = modeling.BertModel(\n",
    "        config = bert_config,\n",
    "        is_training = is_training,\n",
    "        input_ids = input_ids,\n",
    "        input_mask = input_mask,\n",
    "        token_type_ids = segment_ids,\n",
    "        use_one_hot_embeddings = use_one_hot_embeddings,\n",
    "    )\n",
    "\n",
    "    output_layer = model.get_pooled_output()\n",
    "    hidden_size = output_layer.shape[-1].value\n",
    "    with tf.variable_scope('weights', reuse = reuse_flag):\n",
    "        output_weights = tf.get_variable(\n",
    "            'output_weights',\n",
    "            [num_labels, hidden_size],\n",
    "            initializer = tf.truncated_normal_initializer(stddev = 0.02),\n",
    "        )\n",
    "        output_bias = tf.get_variable(\n",
    "            'output_bias', [num_labels], initializer = tf.zeros_initializer()\n",
    "        )\n",
    "\n",
    "    with tf.variable_scope('loss'):\n",
    "        def apply_dropout_last_layer(output_layer):\n",
    "            output_layer = tf.nn.dropout(output_layer, keep_prob = 0.9)\n",
    "            return output_layer\n",
    "\n",
    "        def not_apply_dropout(output_layer):\n",
    "            return output_layer\n",
    "\n",
    "        output_layer = tf.cond(\n",
    "            is_training,\n",
    "            lambda: apply_dropout_last_layer(output_layer),\n",
    "            lambda: not_apply_dropout(output_layer),\n",
    "        )\n",
    "        logits = tf.matmul(output_layer, output_weights, transpose_b = True)\n",
    "        print(\n",
    "            'output_layer:',\n",
    "            output_layer.shape,\n",
    "            ', output_weights:',\n",
    "            output_weights.shape,\n",
    "            ', logits:',\n",
    "            logits.shape,\n",
    "        )\n",
    "\n",
    "        logits = tf.nn.bias_add(logits, output_bias)\n",
    "        probabilities = tf.nn.softmax(logits)\n",
    "        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
    "            labels = labels, logits = logits\n",
    "        )\n",
    "        loss = tf.reduce_mean(loss)\n",
    "        correct_pred = tf.equal(tf.argmax(logits, 1, output_type = tf.int32), labels)\n",
    "        accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))\n",
    "\n",
    "        return loss, logits, probabilities, model, accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "output_layer: (?, 128) , output_weights: (2, 128) , logits: (?, 2)\n"
     ]
    }
   ],
   "source": [
    "use_one_hot_embeddings = False\n",
    "loss, logits, probabilities, model, accuracy = create_model(\n",
    "    bert_config,\n",
    "    is_training,\n",
    "    input_ids,\n",
    "    input_mask,\n",
    "    segment_ids,\n",
    "    label_ids,\n",
    "    dimension_output,\n",
    "    use_one_hot_embeddings,\n",
    ")\n",
    "global_step = tf.Variable(0, trainable = False, name = 'Global_Step')\n",
    "optimizer = tf.contrib.layers.optimize_loss(\n",
    "    loss,\n",
    "    global_step = global_step,\n",
    "    learning_rate = learning_rate,\n",
    "    optimizer = 'Adam',\n",
    "    clip_gradients = 3.0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "vectors = str_idx(trainset.data, dictionary, maxlen)\n",
    "train_X, test_X, train_Y, test_Y = train_test_split(\n",
    "    vectors, trainset.target, test_size = 0.2\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 67/67 [00:06<00:00, 12.10it/s, accuracy=0.593, cost=0.69] \n",
      "test minibatch loop: 100%|██████████| 17/17 [00:00<00:00, 31.92it/s, accuracy=0.671, cost=0.619]\n",
      "train minibatch loop:   3%|▎         | 2/67 [00:00<00:05, 11.60it/s, accuracy=0.703, cost=0.584]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, pass acc: 0.000000, current acc: 0.660026\n",
      "time taken: 6.548045873641968\n",
      "epoch: 0, training loss: 0.692587, training acc: 0.534277, valid loss: 0.655280, valid acc: 0.660026\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 67/67 [00:05<00:00, 11.63it/s, accuracy=0.914, cost=0.202]\n",
      "test minibatch loop: 100%|██████████| 17/17 [00:00<00:00, 34.19it/s, accuracy=0.788, cost=0.518]\n",
      "train minibatch loop:   3%|▎         | 2/67 [00:00<00:05, 11.64it/s, accuracy=0.938, cost=0.158]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1, pass acc: 0.660026, current acc: 0.760851\n",
      "time taken: 6.261883974075317\n",
      "epoch: 1, training loss: 0.437358, training acc: 0.802432, valid loss: 0.633040, valid acc: 0.760851\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 67/67 [00:05<00:00, 11.62it/s, accuracy=0.988, cost=0.0401]\n",
      "test minibatch loop: 100%|██████████| 17/17 [00:00<00:00, 34.41it/s, accuracy=0.765, cost=0.696]\n",
      "train minibatch loop:   3%|▎         | 2/67 [00:00<00:05, 11.49it/s, accuracy=0.992, cost=0.043]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 6.260380268096924\n",
      "epoch: 2, training loss: 0.146797, training acc: 0.954557, valid loss: 0.777705, valid acc: 0.737404\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 67/67 [00:05<00:00, 11.61it/s, accuracy=0.988, cost=0.0312]\n",
      "test minibatch loop: 100%|██████████| 17/17 [00:00<00:00, 33.79it/s, accuracy=0.635, cost=0.959]\n",
      "train minibatch loop:   3%|▎         | 2/67 [00:00<00:05, 11.64it/s, accuracy=0.992, cost=0.0209]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 6.278367280960083\n",
      "epoch: 3, training loss: 0.076096, training acc: 0.983048, valid loss: 0.990929, valid acc: 0.724950\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 67/67 [00:05<00:00, 11.64it/s, accuracy=1, cost=0.00297]   \n",
      "test minibatch loop: 100%|██████████| 17/17 [00:00<00:00, 34.72it/s, accuracy=0.694, cost=1.31] \n",
      "train minibatch loop:   3%|▎         | 2/67 [00:00<00:05, 11.13it/s, accuracy=0.984, cost=0.0425]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 6.247540712356567\n",
      "epoch: 4, training loss: 0.040630, training acc: 0.995310, valid loss: 1.248237, valid acc: 0.725667\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 67/67 [00:05<00:00, 11.55it/s, accuracy=1, cost=0.00343]    \n",
      "test minibatch loop: 100%|██████████| 17/17 [00:00<00:00, 34.33it/s, accuracy=0.694, cost=1.38]\n",
      "train minibatch loop:   3%|▎         | 2/67 [00:00<00:05, 11.55it/s, accuracy=1, cost=0.00337]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 6.31015419960022\n",
      "epoch: 5, training loss: 0.027960, training acc: 0.999297, valid loss: 1.295197, valid acc: 0.738794\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 67/67 [00:05<00:00, 11.58it/s, accuracy=1, cost=0.00415]   \n",
      "test minibatch loop: 100%|██████████| 17/17 [00:00<00:00, 34.38it/s, accuracy=0.741, cost=1.19] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 6.283868312835693\n",
      "epoch: 6, training loss: 0.021000, training acc: 1.001290, valid loss: 1.297261, valid acc: 0.744431\n",
      "\n",
      "break epoch:7\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "import time\n",
    "\n",
    "EARLY_STOPPING, CURRENT_CHECKPOINT, CURRENT_ACC, EPOCH = 5, 0, 0, 0\n",
    "\n",
    "while True:\n",
    "    lasttime = time.time()\n",
    "    if CURRENT_CHECKPOINT == EARLY_STOPPING:\n",
    "        print('break epoch:%d\\n' % (EPOCH))\n",
    "        break\n",
    "\n",
    "    train_acc, train_loss, test_acc, test_loss = 0, 0, 0, 0\n",
    "    pbar = tqdm(\n",
    "        range(0, len(train_X), batch_size), desc = 'train minibatch loop'\n",
    "    )\n",
    "    for i in pbar:\n",
    "        batch_x = train_X[i : min(i + batch_size, train_X.shape[0])]\n",
    "        batch_y = train_Y[i : min(i + batch_size, train_X.shape[0])]\n",
    "        np_mask = np.ones((len(batch_x), maxlen), dtype = np.int32)\n",
    "        np_segment = np.ones((len(batch_x), maxlen), dtype = np.int32)\n",
    "        acc, cost, _ = sess.run(\n",
    "            [accuracy, loss, optimizer],\n",
    "            feed_dict = {\n",
    "                input_ids: batch_x,\n",
    "                label_ids: batch_y,\n",
    "                input_mask: np_mask,\n",
    "                segment_ids: np_segment,\n",
    "                is_training: True\n",
    "            },\n",
    "        )\n",
    "        assert not np.isnan(cost)\n",
    "        train_loss += cost\n",
    "        train_acc += acc\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "\n",
    "    pbar = tqdm(range(0, len(test_X), batch_size), desc = 'test minibatch loop')\n",
    "    for i in pbar:\n",
    "        batch_x = test_X[i : min(i + batch_size, test_X.shape[0])]\n",
    "        batch_y = test_Y[i : min(i + batch_size, test_X.shape[0])]\n",
    "        np_mask = np.ones((len(batch_x), maxlen), dtype = np.int32)\n",
    "        np_segment = np.ones((len(batch_x), maxlen), dtype = np.int32)\n",
    "        acc, cost = sess.run(\n",
    "            [accuracy, loss],\n",
    "            feed_dict = {\n",
    "                input_ids: batch_x,\n",
    "                label_ids: batch_y,\n",
    "                input_mask: np_mask,\n",
    "                segment_ids: np_segment,\n",
    "                is_training: False\n",
    "            },\n",
    "        )\n",
    "        test_loss += cost\n",
    "        test_acc += acc\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "\n",
    "    train_loss /= len(train_X) / batch_size\n",
    "    train_acc /= len(train_X) / batch_size\n",
    "    test_loss /= len(test_X) / batch_size\n",
    "    test_acc /= len(test_X) / batch_size\n",
    "\n",
    "    if test_acc > CURRENT_ACC:\n",
    "        print(\n",
    "            'epoch: %d, pass acc: %f, current acc: %f'\n",
    "            % (EPOCH, CURRENT_ACC, test_acc)\n",
    "        )\n",
    "        CURRENT_ACC = test_acc\n",
    "        CURRENT_CHECKPOINT = 0\n",
    "    else:\n",
    "        CURRENT_CHECKPOINT += 1\n",
    "\n",
    "    print('time taken:', time.time() - lasttime)\n",
    "    print(\n",
    "        'epoch: %d, training loss: %f, training acc: %f, valid loss: %f, valid acc: %f\\n'\n",
    "        % (EPOCH, train_loss, train_acc, test_loss, test_acc)\n",
    "    )\n",
    "    EPOCH += 1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "validation minibatch loop: 100%|██████████| 17/17 [00:00<00:00, 32.90it/s]\n"
     ]
    }
   ],
   "source": [
    "real_Y, predict_Y = [], []\n",
    "\n",
    "pbar = tqdm(\n",
    "    range(0, len(test_X), batch_size), desc = 'validation minibatch loop'\n",
    ")\n",
    "for i in pbar:\n",
    "    batch_x = test_X[i : min(i + batch_size, test_X.shape[0])]\n",
    "    batch_y = test_Y[i : min(i + batch_size, test_X.shape[0])]\n",
    "    np_mask = np.ones((len(batch_x), maxlen), dtype = np.int32)\n",
    "    np_segment = np.ones((len(batch_x), maxlen), dtype = np.int32)\n",
    "    predict_Y += np.argmax(\n",
    "        sess.run(\n",
    "            logits,\n",
    "            feed_dict = {\n",
    "                input_ids: batch_x,\n",
    "                label_ids: batch_y,\n",
    "                input_mask: np_mask,\n",
    "                segment_ids: np_segment,\n",
    "                is_training: False,\n",
    "            },\n",
    "        ),\n",
    "        1,\n",
    "    ).tolist()\n",
    "    real_Y += batch_y\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "             precision    recall  f1-score   support\n",
      "\n",
      "   negative       0.72      0.72      0.72      1044\n",
      "   positive       0.73      0.74      0.74      1089\n",
      "\n",
      "avg / total       0.73      0.73      0.73      2133\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn import metrics\n",
    "print(metrics.classification_report(real_Y, predict_Y, target_names = ['negative','positive']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
